{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Tce3stUlHN0L"
   },
   "source": [
    "##### Copyright 2018 The TensorFlow Authors.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2022-12-14T22:37:41.565069Z",
     "iopub.status.busy": "2022-12-14T22:37:41.564631Z",
     "iopub.status.idle": "2022-12-14T22:37:41.568355Z",
     "shell.execute_reply": "2022-12-14T22:37:41.567807Z"
    },
    "id": "tuOe1ymfHZPu"
   },
   "outputs": [],
   "source": [
    "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MfBg1C5NB3X0"
   },
   "source": [
    "# 使用 TensorFlow 进行分布式训练"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "r6P32iYYV27b"
   },
   "source": [
    "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
    "  <td><a target=\"_blank\" href=\"https://tensorflow.google.cn/guide/distributed_training\" class=\"\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\" class=\"\">在 TensorFlow.org 上查看</a></td>\n",
    "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/guide/distributed_training.ipynb\" class=\"\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\">在 Google Colab 中运行 </a></td>\n",
    "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/guide/distributed_training.ipynb\" class=\"\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\">在 GitHub 上查看源代码</a></td>\n",
    "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/guide/distributed_training.ipynb\" class=\"\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\">下载笔记本</a></td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xHxb-dlhMIzW"
   },
   "source": [
    "## 概述\n",
    "\n",
    "`tf.distribute.Strategy` 是一个可在多个 GPU、多台机器或 TPU 上进行分布式训练的 TensorFlow API。使用此 API，您只需改动较少代码就能分布现有模型和训练代码。\n",
    "\n",
    "`tf.distribute.Strategy` 旨在实现以下目标：\n",
    "\n",
    "- 易于使用，支持多种用户（包括研究人员和机器学习工程师等）。\n",
    "- 提供开箱即用的良好性能。\n",
    "- 轻松切换策略。\n",
    "\n",
    "您可以将 `tf.distribute.Strategy` 与 Keras `Model.fit` 之类的高级 API 以及[自定义训练循环](https://tensorflow.google.cn/guide/keras/writing_a_training_loop_from_scratch)（通常使用 TensorFlow 来进行计算）结合使用来分布训练。\n",
    "\n",
    "在 TensorFlow 2.x 中，您可以在 Eager 模式下执行程序，也可以使用 [`tf.function`](function.ipynb) 在计算图中执行。虽然 `tf.distribute.Strategy` 对两种执行模式都支持，但使用 `tf.function` 效果最佳。建议仅将 Eager 模式用于调试，而 `TPUStrategy` 不支持此模式。尽管本指南大部分时间在讨论训练，但此 API 也可用于在不同平台上分布评估和预测。\n",
    "\n",
    "您在使用 `tf.distribute.Strategy` 时只需改动少量代码，因为我们修改了 TensorFlow 的底层组件，使其可感知策略。这些组件包括变量、层、优化器、指标、摘要和检查点。\n",
    "\n",
    "在本指南中，您将了解各种类型的策略以及如何在不同情况下使用它们。要了解如何调试性能问题，请参阅[优化 TensorFlow GPU 性能](gpu_performance_analysis.md)指南。\n",
    "\n",
    "注：要更深入地了解这些概念，请观看深入演示 [Inside TensorFlow：`tf.distribute.Strategy`](https://youtu.be/jKV53r9-H14)。如果您打算编写自己的训练循环，则特别推荐这样做。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "b3600ee25c8e"
   },
   "source": [
    "## 设置 TensorFlow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-14T22:37:41.571889Z",
     "iopub.status.busy": "2022-12-14T22:37:41.571329Z",
     "iopub.status.idle": "2022-12-14T22:37:43.508308Z",
     "shell.execute_reply": "2022-12-14T22:37:43.507621Z"
    },
    "id": "EVOZFbNgXghB"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-12-14 22:37:42.513051: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
      "2022-12-14 22:37:42.513148: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
      "2022-12-14 22:37:42.513159: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "eQ1QESxxEbCh"
   },
   "source": [
    "## 策略类型\n",
    "\n",
    "`tf.distribute.Strategy` 打算涵盖不同轴上的许多用例。目前已支持其中的部分组合，将来还会添加其他组合。其中一些轴包括：\n",
    "\n",
    "- *同步和异步训练*：这是通过数据并行进行分布式训练的两种常用方法。在同步训练中，所有工作进程都同步地对输入数据的不同片段进行训练，并且会在每一步中聚合梯度。在异步训练中，所有工作进程都独立训练输入数据并异步更新变量。通常情况下，同步训练通过全归约实现，而异步训练通过参数服务器架构实现。\n",
    "- *硬件平台*：您可能需要将训练扩展到一台机器上的多个 GPU 或一个网络中的多台机器（每台机器拥有 0 个或多个 GPU），或扩展到 Cloud TPU 上。\n",
    "\n",
    "为了支持这些用例，TensorFlow 提供了 `MirroredStrategy`、`TPUStrategy`、`MultiWorkerMirroredStrategy`、`ParameterServerStrategy`、`CentralStorageStrategy`，以及其他可用策略。下一部分将具体说明 TensorFlow 中的哪些场景支持上述策略。以下是快速概览：\n",
    "\n",
    "训练 API | `MirroredStrategy` | `TPUStrategy` | `MultiWorkerMirroredStrategy` | `CentralStorageStrategy` | `ParameterServerStrategy`\n",
    ":-- | :-- | :-- | :-- | :-- | :--\n",
    "**Keras `Model.fit`** | 支持 | 支持 | 支持 | 实验性支持 | 实验性支持\n",
    "**自定义训练循环** | 支持 | 支持 | 支持 | 实验性支持 | 实验性支持\n",
    "**Estimator API** | 有限支持 | 不支持 | 有限支持 | 有限支持 | 有限支持\n",
    "\n",
    "注：[实验性支持](https://tensorflow.google.cn/guide/versions#what_is_not_covered)是指兼容性保证不涵盖这些 API。\n",
    "\n",
    "警告：Estimator 为有限支持。基本训练和评估为实验性支持，而基架等高级功能未得到实现。如果未涵盖用例，则应使用 Keras 或自定义训练循环。不建议将 Estimator 用于新代码。Estimator 运行 `v1.Session` 风格的代码，此类代码更加难以正确编写，并且可能会出现意外行为，尤其是与 TF 2 代码结合使用时。Estimator 确实在我们的[兼容性保证](https://tensorflow.org/guide/versions)范围内，但除了安全漏洞之外不会得到任何修复。请转到[迁移指南](https://tensorflow.org/guide/migrate)了解详情。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DoQKKK8dtfg6"
   },
   "source": [
    "### MirroredStrategy\n",
    "\n",
    "`tf.distribute.MirroredStrategy` 支持在一台机器的多个 GPU 上进行同步分布式训练。该策略会为每个 GPU 设备创建一个副本。模型中的每个变量都会在所有副本之间进行镜像。这些变量将共同形成一个名为 `MirroredVariable` 的单个概念变量。这些变量会通过应用相同的更新彼此保持同步。\n",
    "\n",
    "高效的全规约算法用于跨设备传达变量更新。全规约通过将所有设备中的张量相加来聚合它们，并使它们在每个设备上都可用。这是一种极其高效的融合算法，可以显著降低同步产生的开销。有许多全规约算法和实现，具体取决于设备之间可用的通信类型。默认情况下，它使用 NVIDIA Collective Communication Library ([NCCL](https://developer.nvidia.com/nccl)) 作为全规约实现。您可以从其他几个选项中进行选择或者编写自己的选项。\n",
    "\n",
    "以下是创建 `MirroredStrategy` 的最简单方式："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-14T22:37:43.512729Z",
     "iopub.status.busy": "2022-12-14T22:37:43.512320Z",
     "iopub.status.idle": "2022-12-14T22:37:47.928000Z",
     "shell.execute_reply": "2022-12-14T22:37:47.927210Z"
    },
    "id": "9Z4FMAY9ADxK"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n"
     ]
    }
   ],
   "source": [
    "mirrored_strategy = tf.distribute.MirroredStrategy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wldY4aFCAH4r"
   },
   "source": [
    "这会创建一个 `MirroredStrategy` 实例，该实例使用所有对 TensorFlow 可见的 GPU，并使用 NCCL 进行跨设备通信。\n",
    "\n",
    "如果您只想使用机器上的部分 GPU，您可以这样做："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-14T22:37:47.932014Z",
     "iopub.status.busy": "2022-12-14T22:37:47.931475Z",
     "iopub.status.idle": "2022-12-14T22:37:47.938608Z",
     "shell.execute_reply": "2022-12-14T22:37:47.938040Z"
    },
    "id": "nbGleskCACv_"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')\n"
     ]
    }
   ],
   "source": [
    "mirrored_strategy = tf.distribute.MirroredStrategy(devices=[\"/gpu:0\", \"/gpu:1\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8-KDnrJLAhav"
   },
   "source": [
    "如果您想重写跨设备通信，可以通过提供 `tf.distribute.CrossDeviceOps` 的实例，使用 `cross_device_ops` 参数来实现。目前，除了默认选项 `tf.distribute.NcclAllReduce` 外，还有 `tf.distribute.HierarchicalCopyAllReduce` 和 `tf.distribute.ReductionToOneDevice` 两个选项。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-14T22:37:47.942122Z",
     "iopub.status.busy": "2022-12-14T22:37:47.941486Z",
     "iopub.status.idle": "2022-12-14T22:37:47.947108Z",
     "shell.execute_reply": "2022-12-14T22:37:47.946538Z"
    },
    "id": "6-xIOIpgBItn"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n"
     ]
    }
   ],
   "source": [
    "mirrored_strategy = tf.distribute.MirroredStrategy(\n",
    "    cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kPEBCMzsGaO5"
   },
   "source": [
    "### TPUStrategy\n",
    "\n",
    "您可以使用 `tf.distribute.experimental.TPUStrategy` 在张量处理单元 (TPU) 上运行 TensorFlow 训练。TPU 是 Google 的专用 ASIC，旨在显著加速机器学习工作负载。您可通过 Google Colab、[TensorFlow Research Cloud](https://tensorflow.google.cn/tfrc) 和 [Cloud TPU](https://cloud.google.com/tpu) 平台进行使用。\n",
    "\n",
    "就分布式训练架构而言，`TPUStrategy` 和 `MirroredStrategy` 是一样的，即实现同步分布式训练。TPU 会在多个 TPU 核心之间实现高效的全归约和其他集合运算，并将其用于 `TPUStrategy`。\n",
    "\n",
    "下面演示了如何将 `TPUStrategy` 实例化：\n",
    "\n",
    "注：要在 Colab 中运行任何 TPU 代码，应将 TPU 作为 Colab 运行时。请参阅[使用 TPU](tpu.ipynb) 指南获得完整示例。\n",
    "\n",
    "```python\n",
    "cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(\n",
    "    tpu=tpu_address)\n",
    "tf.config.experimental_connect_to_cluster(cluster_resolver)\n",
    "tf.tpu.experimental.initialize_tpu_system(cluster_resolver)\n",
    "tpu_strategy = tf.distribute.TPUStrategy(cluster_resolver)\n",
    "```\n",
    "\n",
    "`TPUClusterResolver` 实例可帮助定位 TPU。在 Colab 中，您无需为其指定任何参数。\n",
    "\n",
    "如果您想要将其用于 Cloud TPU，则必须执行以下操作：\n",
    "\n",
    "- 在 `tpu` 参数中指定 TPU 资源的名称。\n",
    "- 在程序*开始*时显式地初始化 TPU 系统。这是使用 TPU 进行计算前的必需步骤。初始化 TPU 系统还会清除 TPU 内存，所以为了避免丢失状态，请务必先完成此步骤。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8Xc3gyo0Bejd"
   },
   "source": [
    "### MultiWorkerMirroredStrategy\n",
    "\n",
    "`tf.distribute.MultiWorkerMirroredStrategy` 与 `MirroredStrategy` 非常相似。它实现了跨多个工作进程的同步分布式训练，而每个工作进程可能有多个 GPU。与 `MirroredStrategy` 类似，它也会跨所有工作进程在每个设备的模型中创建所有变量的副本。\n",
    "\n",
    "以下是创建 `MultiWorkerMirroredStrategy` 的最简单方式："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-14T22:37:47.950294Z",
     "iopub.status.busy": "2022-12-14T22:37:47.949823Z",
     "iopub.status.idle": "2022-12-14T22:37:47.957716Z",
     "shell.execute_reply": "2022-12-14T22:37:47.957174Z"
    },
    "id": "m3a_6ebbEjre"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0', '/device:GPU:1', '/device:GPU:2', '/device:GPU:3'), communication = CommunicationImplementation.AUTO\n"
     ]
    }
   ],
   "source": [
    "strategy = tf.distribute.MultiWorkerMirroredStrategy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bt94JBvhEr4s"
   },
   "source": [
    "`MultiWorkerMirroredStrategy` 有两种用于跨设备通信的实现。`CommunicationImplementation.RING` 基于 [RPC](https://en.wikipedia.org/wiki/Remote_procedure_call)，同时支持 CPU 和 GPU。`CommunicationImplementation.NCCL` 使用 NCCL 并在 GPU 上提供最先进的性能，但它不支持 CPU。`CollectiveCommunication.AUTO` 将选择权交给 Tensorflow。您可以通过以下方式指定它们：\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-14T22:37:47.961049Z",
     "iopub.status.busy": "2022-12-14T22:37:47.960459Z",
     "iopub.status.idle": "2022-12-14T22:37:47.967381Z",
     "shell.execute_reply": "2022-12-14T22:37:47.966801Z"
    },
    "id": "QGX_QAEtFQSv"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0', '/device:GPU:1', '/device:GPU:2', '/device:GPU:3'), communication = CommunicationImplementation.NCCL\n"
     ]
    }
   ],
   "source": [
    "communication_options = tf.distribute.experimental.CommunicationOptions(\n",
    "    implementation=tf.distribute.experimental.CommunicationImplementation.NCCL)\n",
    "strategy = tf.distribute.MultiWorkerMirroredStrategy(\n",
    "    communication_options=communication_options)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0JiImlw3F77E"
   },
   "source": [
    "与多 GPU 训练相比，多工作进程训练的一个主要差异是多工作进程的设置。`'TF_CONFIG'` 环境变量是在 TensorFlow 中为作为集群一部分的每个工作进程指定集群配置的标准方式。请在本文档的[设置 TF_CONFIG](#TF_CONFIG) 部分了解详细信息。\n",
    "\n",
    "有关 `MultiWorkerMirroredStrategy` 的详细信息，请参阅以下教程：\n",
    "\n",
    "- [使用 Keras Model.fit 进行多工作进程训练](../tutorials/distribute/multi_worker_with_keras.ipynb)\n",
    "- [使用自定义训练循环进行多工作进程训练](../tutorials/distribute/multi_worker_with_ctl.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3ZLBhaP9NUNr"
   },
   "source": [
    "### ParameterServerStrategy\n",
    "\n",
    "参数服务器训练是一种常见的数据并行方法，用于在多台机器上扩展模型训练。参数服务器训练集群由工作进程和参数服务器组成。变量在参数服务器上创建，并在每个步骤中由工作进程读取和更新。查看[参数服务器培训](../tutorials/distribute/parameter_server_training.ipynb)教程了解详情。\n",
    "\n",
    "在 TensorFlow 2 中，参数服务器训练通过 {code 0}tf.distribute.experimental.coordinator.Cluster Coordinator{/code 0} 类使用基于中央协调器的架构。\n",
    "\n",
    "在此实现中，`worker` 和 `parameter server` 任务运行侦听来自协调器的任务的  `tf.distribute.Server`。协调器创建资源、调度训练任务、编写检查点并处理任务失败。\n",
    "\n",
    "在协调器上运行的编程中，您将使用 `ParameterServerStrategy` 对象定义训练步骤，并使用 `ClusterCoordinator` 将训练步骤分派给远程工作进程。这是创建它们的最简单方式：\n",
    "\n",
    "```python\n",
    "strategy = tf.distribute.experimental.ParameterServerStrategy(\n",
    "    tf.distribute.cluster_resolver.TFConfigClusterResolver(),\n",
    "    variable_partitioner=variable_partitioner)\n",
    "coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(\n",
    "    strategy)\n",
    "```\n",
    "\n",
    "要了解有关 `ParameterServerStrategy` 的详细信息，请参阅[使用 Keras Model.fit 和自定义训练循环进行参数服务器训练](../tutorials/distribute/parameter_server_training.ipynb)教程。\n",
    "\n",
    "注：如果使用 `TFConfigClusterResolver`，则需要配置 `'TF_CONFIG'` 环境变量。它类似于 `MultiWorkerMirroredStrategy` 中的 <a data-md-type=\"raw_html\" href=\"#TF_CONFIG\">`'TF_CONFIG'`</a>，但具有额外的注意事项。\n",
    "\n",
    "在 TensorFlow 1 中，`ParameterServerStrategy`只能通过 `tf.compat.v1.distribute.experimental.ParameterServerStrategy` 符号在 Estimator 中使用。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "E20tG21LFfv1"
   },
   "source": [
    "注：此策略是 [`experimental`](https://tensorflow.google.cn/guide/versions#what_is_not_covered)，因为它目前正在进行积极开发。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "45H0Wa8WKI8z"
   },
   "source": [
    "### CentralStorageStrategy\n",
    "\n",
    "`tf.distribute.experimental.CentralStorageStrategy` 也执行同步训练。变量不会被镜像，而是放在 CPU 上，且运算会复制到所有本地 GPU 。如果只有一个 GPU，则所有变量和运算都将被放在该 GPU 上。\n",
    "\n",
    "请通过以下代码，创建 `CentralStorageStrategy` 实例：\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-14T22:37:47.971042Z",
     "iopub.status.busy": "2022-12-14T22:37:47.970462Z",
     "iopub.status.idle": "2022-12-14T22:37:47.974365Z",
     "shell.execute_reply": "2022-12-14T22:37:47.973772Z"
    },
    "id": "rtjZOyaoMWrP"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:ParameterServerStrategy (CentralStorageStrategy if you are using a single machine) with compute_devices = ['/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3'], variable_device = '/device:CPU:0'\n"
     ]
    }
   ],
   "source": [
    "central_storage_strategy = tf.distribute.experimental.CentralStorageStrategy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KY1nJHNkMl7b"
   },
   "source": [
    "这会创建一个 `CentralStorageStrategy` 实例，该实例将使用所有可见的 GPU 和 CPU。在副本上对变量的更新将先进行聚合，然后再应用于变量。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aAFycYUiNCUb"
   },
   "source": [
    "注：此策略是 [`experimental`](https://tensorflow.google.cn/guide/versions#what_is_not_covered)，因为它目前正在进行开发。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "t2XUdmIxKljq"
   },
   "source": [
    "### 其他策略\n",
    "\n",
    "除上述策略外，还有其他两种策略可能对使用 `tf.distribute` API 进行原型设计和调试有所帮助。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UD5I1beTpc7a"
   },
   "source": [
    "#### 默认策略\n",
    "\n",
    "默认策略是一种分布策略，当作用域内没有显式分布策略时就会出现。此策略会实现 `tf.distribute.Strategy` 接口，但只具有传递功能，不提供实际分布。例如，`Strategy.run(fn)` 只会调用 `fn`。使用该策略编写的代码与未使用任何策略编写的代码完全一样。您可以将其视为“无运算”策略。\n",
    "\n",
    "默认策略是一种单一实例，无法创建它的更多实例。可以在任何显式策略作用域之外使用 `tf.distribute.get_strategy` 来获取它（可用于在显式策略作用域内获取当前策略的相同 API）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-14T22:37:47.978003Z",
     "iopub.status.busy": "2022-12-14T22:37:47.977457Z",
     "iopub.status.idle": "2022-12-14T22:37:47.980474Z",
     "shell.execute_reply": "2022-12-14T22:37:47.979920Z"
    },
    "id": "ibHleFOOmPn9"
   },
   "outputs": [],
   "source": [
    "default_strategy = tf.distribute.get_strategy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EkxPl_5ImLzc"
   },
   "source": [
    "此策略有两个主要用途：\n",
    "\n",
    "- 它允许无条件地编写可感知分发的库代码。例如，在 `tf.keras.optimizers` 中，您可以使用 `tf.distribute.get_strategy`，并用此策略来降低梯度 - 它将始终返回一个策略对象，您可以在该对象上调用 `Strategy.reduce` API。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-14T22:37:47.983868Z",
     "iopub.status.busy": "2022-12-14T22:37:47.983337Z",
     "iopub.status.idle": "2022-12-14T22:37:47.989561Z",
     "shell.execute_reply": "2022-12-14T22:37:47.989048Z"
    },
    "id": "WECeRzUdT6bU"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# In optimizer or other library code\n",
    "# Get currently active strategy\n",
    "strategy = tf.distribute.get_strategy()\n",
    "strategy.reduce(\"SUM\", 1., axis=None)  # reduce some values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JURbH-pUT51B"
   },
   "source": [
    "- 与库代码类似，它可用于编写最终用户的程序以便使用或不使用分布策略，而无需条件逻辑。下面是一个说明了这一点的示例代码段："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-14T22:37:47.992789Z",
     "iopub.status.busy": "2022-12-14T22:37:47.992284Z",
     "iopub.status.idle": "2022-12-14T22:37:48.012921Z",
     "shell.execute_reply": "2022-12-14T22:37:48.012361Z"
    },
    "id": "O4Vmae5jmSE6"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MirroredVariable:{\n",
      "  0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,\n",
      "  1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>,\n",
      "  2: <tf.Variable 'Variable/replica_2:0' shape=() dtype=float32, numpy=1.0>,\n",
      "  3: <tf.Variable 'Variable/replica_3:0' shape=() dtype=float32, numpy=1.0>\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "if tf.config.list_physical_devices('GPU'):\n",
    "  strategy = tf.distribute.MirroredStrategy()\n",
    "else:  # Use the Default Strategy\n",
    "  strategy = tf.distribute.get_strategy()\n",
    "\n",
    "with strategy.scope():\n",
    "  # Do something interesting\n",
    "  print(tf.Variable(1.))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kTzsqN4lmJ0d"
   },
   "source": [
    "#### OneDeviceStrategy\n",
    "\n",
    "`tf.distribute.OneDeviceStrategy` 是一种会将所有变量和计算放在单个指定设备上的策略。\n",
    "\n",
    "```python\n",
    "strategy = tf.distribute.OneDeviceStrategy(device=\"/gpu:0\")\n",
    "```\n",
    "\n",
    "此策略在许多方面与默认策略不同。在默认策略中，与在未使用任何分布策略的情况下运行 TensorFlow 相比，变量布局逻辑保持不变。但是，在使用 `OneDeviceStrategy` 时，在其作用域内创建的所有变量都会显式放置在指定的设备上。此外，通过 `OneDeviceStrategy.run` 调用的任何函数也将放置在指定的设备上。\n",
    "\n",
    "通过此策略分布的输入将被预获取到指定设备。在默认策略中，没有输入分布。\n",
    "\n",
    "与默认策略类似，在切换到实际分布到多个设备/机器的其他策略之前，也可以使用此策略来测试代码。这将比默认策略更多地使用分布策略机制，但不能像使用 `MirroredStrategy` 或 `TPUStrategy` 等策略那样充分发挥其作用。如果您想让代码表现地像没有策略，请使用默认策略。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hQv1lm9UPDFy"
   },
   "source": [
    "到目前为止，您已经看到了不同策略以及如何将它们实例化。接下来的几个部分将介绍可以使用它们来分布训练的不同方式。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_mcuy3UhPcen"
   },
   "source": [
    "## 在 Keras Model.fit 中使用 tf.distribute.Strategy\n",
    "\n",
    "`tf.distribute.Strategy` 已集成到 `tf.keras` 中，后者是 TensorFlow 对 [Keras API 规范](https://keras.io)的实现。`tf.keras` 是用于构建和训练模型的高级 API。通过集成到 `tf.keras` 后端，您可以无缝使用 <code>Model.fit</code> 来分布以 Keras 训练框架编写的训练。\n",
    "\n",
    "您需要对代码进行以下更改：\n",
    "\n",
    "1. 创建一个合适的 `tf.distribute.Strategy` 实例。\n",
    "2. 将 Keras 模型、优化器和指标的创建移到 `strategy.scope` 中。因此，模型的 `call()`、`train_step()` 和 `test_step()` 方法中的代码都将在加速器上分布和执行。\n",
    "\n",
    "TensorFlow 分布策略支持所有类型的 Keras 模型 - 序贯、函数式和子类化。\n",
    "\n",
    "下面是一段代码，执行该代码会创建一个非常简单的带有一个 `Dense` 层的 Keras 模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-14T22:37:48.016291Z",
     "iopub.status.busy": "2022-12-14T22:37:48.015788Z",
     "iopub.status.idle": "2022-12-14T22:37:48.118300Z",
     "shell.execute_reply": "2022-12-14T22:37:48.117706Z"
    },
    "id": "gbbcpzRnPZ6V"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
     ]
    }
   ],
   "source": [
    "mirrored_strategy = tf.distribute.MirroredStrategy()\n",
    "\n",
    "with mirrored_strategy.scope():\n",
    "  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])\n",
    "\n",
    "model.compile(loss='mse', optimizer='sgd')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "773EOxCRVlTg"
   },
   "source": [
    "此示例使用 `MirroredStrategy`，因此您可以在具有多个 GPU 的计算机上运行。`strategy.scope()` 指示 Keras 使用哪种策略来分布训练。通过在此作用域内创建模型/优化器/指标，您可以创建分布式变量而不是常规变量。设置完成后，您可以像往常一样拟合模型。`MirroredStrategy` 负责在可用 GPU 上复制模型的训练、聚合梯度等。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-14T22:37:48.121762Z",
     "iopub.status.busy": "2022-12-14T22:37:48.121281Z",
     "iopub.status.idle": "2022-12-14T22:37:55.304721Z",
     "shell.execute_reply": "2022-12-14T22:37:55.303934Z"
    },
    "id": "ZMmxEFRTEjH5"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-12-14 22:37:48.151848: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:784] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: \"TensorDataset/_2\"\n",
      "op: \"TensorDataset\"\n",
      "input: \"Placeholder/_0\"\n",
      "input: \"Placeholder/_1\"\n",
      "attr {\n",
      "  key: \"Toutput_types\"\n",
      "  value {\n",
      "    list {\n",
      "      type: DT_FLOAT\n",
      "      type: DT_FLOAT\n",
      "    }\n",
      "  }\n",
      "}\n",
      "attr {\n",
      "  key: \"_cardinality\"\n",
      "  value {\n",
      "    i: 1\n",
      "  }\n",
      "}\n",
      "attr {\n",
      "  key: \"metadata\"\n",
      "  value {\n",
      "    s: \"\\n\\017TensorDataset:0\"\n",
      "  }\n",
      "}\n",
      "attr {\n",
      "  key: \"output_shapes\"\n",
      "  value {\n",
      "    list {\n",
      "      shape {\n",
      "        dim {\n",
      "          size: 1\n",
      "        }\n",
      "      }\n",
      "      shape {\n",
      "        dim {\n",
      "          size: 1\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n",
      "experimental_type {\n",
      "  type_id: TFT_PRODUCT\n",
      "  args {\n",
      "    type_id: TFT_DATASET\n",
      "    args {\n",
      "      type_id: TFT_PRODUCT\n",
      "      args {\n",
      "        type_id: TFT_TENSOR\n",
      "        args {\n",
      "          type_id: TFT_FLOAT\n",
      "        }\n",
      "      }\n",
      "      args {\n",
      "        type_id: TFT_TENSOR\n",
      "        args {\n",
      "          type_id: TFT_FLOAT\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:batch_all_reduce: 2 all-reduces with algorithm = nccl, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:batch_all_reduce: 2 all-reduces with algorithm = nccl, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      " 1/10 [==>...........................] - ETA: 40s - loss: 5.7945"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
      "10/10 [==============================] - 5s 5ms/step - loss: 4.1241\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2/2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      " 1/10 [==>...........................] - ETA: 0s - loss: 2.5612"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
      "10/10 [==============================] - 0s 5ms/step - loss: 1.8229\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-12-14 22:37:53.511054: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:784] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: \"TensorDataset/_2\"\n",
      "op: \"TensorDataset\"\n",
      "input: \"Placeholder/_0\"\n",
      "input: \"Placeholder/_1\"\n",
      "attr {\n",
      "  key: \"Toutput_types\"\n",
      "  value {\n",
      "    list {\n",
      "      type: DT_FLOAT\n",
      "      type: DT_FLOAT\n",
      "    }\n",
      "  }\n",
      "}\n",
      "attr {\n",
      "  key: \"_cardinality\"\n",
      "  value {\n",
      "    i: 1\n",
      "  }\n",
      "}\n",
      "attr {\n",
      "  key: \"metadata\"\n",
      "  value {\n",
      "    s: \"\\n\\017TensorDataset:0\"\n",
      "  }\n",
      "}\n",
      "attr {\n",
      "  key: \"output_shapes\"\n",
      "  value {\n",
      "    list {\n",
      "      shape {\n",
      "        dim {\n",
      "          size: 1\n",
      "        }\n",
      "      }\n",
      "      shape {\n",
      "        dim {\n",
      "          size: 1\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n",
      "experimental_type {\n",
      "  type_id: TFT_PRODUCT\n",
      "  args {\n",
      "    type_id: TFT_DATASET\n",
      "    args {\n",
      "      type_id: TFT_PRODUCT\n",
      "      args {\n",
      "        type_id: TFT_TENSOR\n",
      "        args {\n",
      "          type_id: TFT_FLOAT\n",
      "        }\n",
      "      }\n",
      "      args {\n",
      "        type_id: TFT_TENSOR\n",
      "        args {\n",
      "          type_id: TFT_FLOAT\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      " 1/10 [==>...........................] - ETA: 15s - loss: 1.1320"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
      "10/10 [==============================] - 2s 4ms/step - loss: 1.1320\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "1.1320464611053467"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(10)\n",
    "model.fit(dataset, epochs=2)\n",
    "model.evaluate(dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nofTLwyXWHK8"
   },
   "source": [
    "我们在这里使用了 `tf.data.Dataset` 来提供训练和评估输入。您还可以使用 Numpy 数组："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-14T22:37:55.308622Z",
     "iopub.status.busy": "2022-12-14T22:37:55.307954Z",
     "iopub.status.idle": "2022-12-14T22:37:55.675032Z",
     "shell.execute_reply": "2022-12-14T22:37:55.674134Z"
    },
    "id": "Lqgd9SdxW5OW"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      " 1/10 [==>...........................] - ETA: 0s - loss: 1.1320"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
      "10/10 [==============================] - 0s 5ms/step - loss: 0.8057\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2/2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      " 1/10 [==>...........................] - ETA: 0s - loss: 0.5004"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
      "10/10 [==============================] - 0s 5ms/step - loss: 0.3561\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7efe903a8f40>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "inputs, targets = np.ones((100, 1)), np.ones((100, 1))\n",
    "model.fit(inputs, targets, epochs=2, batch_size=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IKqaj7QwX0Zb"
   },
   "source": [
    "在上述两个示例（`Dataset` 或 Numpy）中，给定输入的每个批次都被平均分到了多个副本中。例如，如果对 2 个 GPU 使用 `MirroredStrategy`，大小为 10 的每个批次将被均分到 2 个 GPU，每个 GPU 会在每步接收 5 个输入样本。如果添加更多 GPU，则每个周期的训练速度会更快。通常，您希望在添加更多加速器时增加批次大小，以便有效利用额外的计算能力。您还需要根据模型重新调整您的学习速率。您可以使用 `strategy.num_replicas_in_sync` 获得副本数量。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-14T22:37:55.678796Z",
     "iopub.status.busy": "2022-12-14T22:37:55.678256Z",
     "iopub.status.idle": "2022-12-14T22:37:55.683552Z",
     "shell.execute_reply": "2022-12-14T22:37:55.682695Z"
    },
    "id": "8ZmJqErtS4A1"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mirrored_strategy.num_replicas_in_sync"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-14T22:37:55.686807Z",
     "iopub.status.busy": "2022-12-14T22:37:55.686338Z",
     "iopub.status.idle": "2022-12-14T22:37:55.693419Z",
     "shell.execute_reply": "2022-12-14T22:37:55.692578Z"
    },
    "id": "quNNTytWdGBf"
   },
   "outputs": [],
   "source": [
    "# Compute a global batch size using a number of replicas.\n",
    "BATCH_SIZE_PER_REPLICA = 5\n",
    "global_batch_size = (BATCH_SIZE_PER_REPLICA *\n",
    "                     mirrored_strategy.num_replicas_in_sync)\n",
    "dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100)\n",
    "dataset = dataset.batch(global_batch_size)\n",
    "\n",
    "LEARNING_RATES_BY_BATCH_SIZE = {5: 0.1, 10: 0.15, 20:0.175}\n",
    "learning_rate = LEARNING_RATES_BY_BATCH_SIZE[global_batch_size]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "z1Muy0gDZwO5"
   },
   "source": [
    "### 目前支持的策略\n",
    "\n",
    "训练 API | `MirroredStrategy` | `TPUStrategy` | `MultiWorkerMirroredStrategy` | `ParameterServerStrategy` | `CentralStorageStrategy`\n",
    "--- | --- | --- | --- | --- | ---\n",
    "Keras `Model.fit` | 支持 | 支持 | 实验性支持 | 实验性支持 | 实验性支持\n",
    "\n",
    "### 示例和教程\n",
    "\n",
    "下面列出了说明上述与 Keras `Model.fit` 端到端集成的教程和示例：\n",
    "\n",
    "1. [教程](../tutorials/distribute/keras.ipynb)：使用 `Model.fit` 和 `MirroredStrategy` 进行训练。\n",
    "2. [教程](../tutorials/distribute/multi_worker_with_keras.ipynb)：使用 `Model.fit` 和 `MultiWorkerMirroredStrategy` 进行训练。\n",
    "3. [指南](tpu.ipynb)：包含使用`Model.fit` 和 `TPUStrategy` 的示例。\n",
    "4. [教程](../tutorials/distribute/parameter_server_training.ipynb)：使用`Model.fit` 和 `ParameterServerStrategy` 进行参数服务器训练。\n",
    "5. [教程](https://tensorflow.google.cn/text/tutorials/bert_glue)：使用 `Model.fit` 和 `TPUStrategy` 微调 GLUE 基准测试中的许多任务的 BERT。\n",
    "6. 包含使用各种策略实现的最先进模型集合的 TensorFlow Model Garden [仓库](https://github.com/tensorflow/models/tree/master/official)。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IlYVC0goepdk"
   },
   "source": [
    "## 在自定义训练循环中使用 `tf.distribute.Strategy`\n",
    "\n",
    "如上所述，在 Keras `Model.fit` 中使用 `tf.distribute.Strategy` 只需改动几行代码。再多花点功夫，您还可以在[自定义训练循环](/keras/writing_a_training_loop_from_scratch.ipynb)中使用 `tf.distribute.Strategy`。\n",
    "\n",
    "如果您需要更多相对于使用 Estimator 或 Keras 时的灵活性和对训练循环的控制权，您可以编写自定义训练循环。例如，在使用 GAN 时，您可能会希望每轮使用不同数量的生成器或判别器步骤。同样，高级框架也不太适合强化学习训练。\n",
    "\n",
    "`tf.distribute.Strategy` 类提供了一组核心方法来支持自定义训练循环。使用这些方法时，最初可能需要对代码进行少量重构，但是一旦完成，您便能够通过更改策略实例在 GPU、TPU 和多台计算机之间切换。\n",
    "\n",
    "下面是用来说明此用例的一个简短的代码段，其中的简单训练样本使用与之前相同的 Keras 模型。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XNHvSY32nVBi"
   },
   "source": [
    "首先，在该策略的作用域内创建模型和优化器。这样可以确保使用模型和优化器创建的任何变量都是镜像变量。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-14T22:37:55.697021Z",
     "iopub.status.busy": "2022-12-14T22:37:55.696502Z",
     "iopub.status.idle": "2022-12-14T22:37:55.734054Z",
     "shell.execute_reply": "2022-12-14T22:37:55.733275Z"
    },
    "id": "W-3Bn-CaiPKD"
   },
   "outputs": [],
   "source": [
    "with mirrored_strategy.scope():\n",
    "  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])\n",
    "  optimizer = tf.keras.optimizers.SGD()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mYkAyPeYnlXk"
   },
   "source": [
    "接下来，创建输入数据集并调用 `tf.distribute.Strategy.experimental_distribute_dataset` 以根据策略分布数据集。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-14T22:37:55.737719Z",
     "iopub.status.busy": "2022-12-14T22:37:55.737085Z",
     "iopub.status.idle": "2022-12-14T22:37:55.754896Z",
     "shell.execute_reply": "2022-12-14T22:37:55.754146Z"
    },
    "id": "94BkvkLInkKd"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-12-14 22:37:55.749655: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:784] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: \"TensorDataset/_2\"\n",
      "op: \"TensorDataset\"\n",
      "input: \"Placeholder/_0\"\n",
      "input: \"Placeholder/_1\"\n",
      "attr {\n",
      "  key: \"Toutput_types\"\n",
      "  value {\n",
      "    list {\n",
      "      type: DT_FLOAT\n",
      "      type: DT_FLOAT\n",
      "    }\n",
      "  }\n",
      "}\n",
      "attr {\n",
      "  key: \"_cardinality\"\n",
      "  value {\n",
      "    i: 1\n",
      "  }\n",
      "}\n",
      "attr {\n",
      "  key: \"metadata\"\n",
      "  value {\n",
      "    s: \"\\n\\021TensorDataset:126\"\n",
      "  }\n",
      "}\n",
      "attr {\n",
      "  key: \"output_shapes\"\n",
      "  value {\n",
      "    list {\n",
      "      shape {\n",
      "        dim {\n",
      "          size: 1\n",
      "        }\n",
      "      }\n",
      "      shape {\n",
      "        dim {\n",
      "          size: 1\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n",
      "experimental_type {\n",
      "  type_id: TFT_PRODUCT\n",
      "  args {\n",
      "    type_id: TFT_DATASET\n",
      "    args {\n",
      "      type_id: TFT_PRODUCT\n",
      "      args {\n",
      "        type_id: TFT_TENSOR\n",
      "        args {\n",
      "          type_id: TFT_FLOAT\n",
      "        }\n",
      "      }\n",
      "      args {\n",
      "        type_id: TFT_TENSOR\n",
      "        args {\n",
      "          type_id: TFT_FLOAT\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(1000).batch(\n",
    "    global_batch_size)\n",
    "dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "grzmTlSvn2j8"
   },
   "source": [
    "然后，定义一个训练步骤。使用 `tf.GradientTape` 计算梯度并使用优化器应用这些梯度来更新模型的变量。要分布此训练步骤，将其放入函数 `train_step` 中，然后将其与您从之前创建的 `dist_dataset` 中获得的数据集输入一起传递给 `tf.distribute.Strategy.run`："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-14T22:37:55.758021Z",
     "iopub.status.busy": "2022-12-14T22:37:55.757708Z",
     "iopub.status.idle": "2022-12-14T22:37:55.763642Z",
     "shell.execute_reply": "2022-12-14T22:37:55.762845Z"
    },
    "id": "NJxL5YrVniDe"
   },
   "outputs": [],
   "source": [
    "loss_object = tf.keras.losses.BinaryCrossentropy(\n",
    "  from_logits=True,\n",
    "  reduction=tf.keras.losses.Reduction.NONE)\n",
    "\n",
    "def compute_loss(labels, predictions):\n",
    "  per_example_loss = loss_object(labels, predictions)\n",
    "  return tf.nn.compute_average_loss(per_example_loss, global_batch_size=global_batch_size)\n",
    "\n",
    "def train_step(inputs):\n",
    "  features, labels = inputs\n",
    "\n",
    "  with tf.GradientTape() as tape:\n",
    "    predictions = model(features, training=True)\n",
    "    loss = compute_loss(labels, predictions)\n",
    "\n",
    "  gradients = tape.gradient(loss, model.trainable_variables)\n",
    "  optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "  return loss\n",
    "\n",
    "@tf.function\n",
    "def distributed_train_step(dist_inputs):\n",
    "  per_replica_losses = mirrored_strategy.run(train_step, args=(dist_inputs,))\n",
    "  return mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,\n",
    "                         axis=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yRL5u_NLoTvq"
   },
   "source": [
    "以上代码还需注意以下几点：\n",
    "\n",
    "1. 您使用了 `tf.nn.compute_average_loss` 来计算损失。`tf.nn.compute_average_loss` 将每个样本的损失相加，然后将总和除以 `global_batch_size`。这很重要，因为稍后在每个副本上计算出梯度后，会通过对它们**求和**使其在副本中聚合。\n",
    "2. 您还使用了 `tf.distribute.Strategy.reduce` API 来聚合 `tf.distribute.Strategy.run` 返回的结果。`tf.distribute.Strategy.run` 会从策略中的每个本地副本返回结果，您可以通过多种方式使用此结果。可以 `reduce` 它们以获得聚合值。还可以通过执行 `tf.distribute.Strategy.experimental_local_results` 获得包含在结果中的值的列表，每个本地副本一个列表。\n",
    "3. 当在一个分布策略作用域内调用 `apply_gradients` 时，它的行为会被修改。具体来说，在同步训练期间，在将梯度应用于每个并行实例之前，它会对梯度的所有副本求和。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "o9k_6-6vpQ-P"
   },
   "source": [
    "最后，当我们定义完训练步骤后，就可以迭代 `dist_dataset`，并在循环中运行训练："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-14T22:37:55.767147Z",
     "iopub.status.busy": "2022-12-14T22:37:55.766531Z",
     "iopub.status.idle": "2022-12-14T22:37:59.369126Z",
     "shell.execute_reply": "2022-12-14T22:37:59.368389Z"
    },
    "id": "Egq9eufToRf6"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:batch_all_reduce: 2 all-reduces with algorithm = nccl, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:batch_all_reduce: 2 all-reduces with algorithm = nccl, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(1.4526782, shape=(), dtype=float32)\n",
      "tf.Tensor(1.4409623, shape=(), dtype=float32)\n",
      "tf.Tensor(1.429331, shape=(), dtype=float32)\n",
      "tf.Tensor(1.417784, shape=(), dtype=float32)\n",
      "tf.Tensor(1.4063214, shape=(), dtype=float32)\n",
      "tf.Tensor(1.3949434, shape=(), dtype=float32)\n",
      "tf.Tensor(1.3836498, shape=(), dtype=float32)\n",
      "tf.Tensor(1.3724408, shape=(), dtype=float32)\n",
      "tf.Tensor(1.3613163, shape=(), dtype=float32)\n",
      "tf.Tensor(1.3502765, shape=(), dtype=float32)\n",
      "tf.Tensor(1.339321, shape=(), dtype=float32)\n",
      "tf.Tensor(1.3284501, shape=(), dtype=float32)\n",
      "tf.Tensor(1.3176632, shape=(), dtype=float32)\n",
      "tf.Tensor(1.3069606, shape=(), dtype=float32)\n",
      "tf.Tensor(1.2963424, shape=(), dtype=float32)\n",
      "tf.Tensor(1.2858084, shape=(), dtype=float32)\n",
      "tf.Tensor(1.2753581, shape=(), dtype=float32)\n",
      "tf.Tensor(1.2649918, shape=(), dtype=float32)\n",
      "tf.Tensor(1.2547092, shape=(), dtype=float32)\n",
      "tf.Tensor(1.24451, shape=(), dtype=float32)\n",
      "tf.Tensor(1.2343943, shape=(), dtype=float32)\n",
      "tf.Tensor(1.2243619, shape=(), dtype=float32)\n",
      "tf.Tensor(1.2144123, shape=(), dtype=float32)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(1.2045455, shape=(), dtype=float32)\n",
      "tf.Tensor(1.1947614, shape=(), dtype=float32)\n",
      "tf.Tensor(1.1850595, shape=(), dtype=float32)\n",
      "tf.Tensor(1.1754398, shape=(), dtype=float32)\n",
      "tf.Tensor(1.1659018, shape=(), dtype=float32)\n",
      "tf.Tensor(1.1564454, shape=(), dtype=float32)\n",
      "tf.Tensor(1.1470703, shape=(), dtype=float32)\n",
      "tf.Tensor(1.1377763, shape=(), dtype=float32)\n",
      "tf.Tensor(1.1285628, shape=(), dtype=float32)\n",
      "tf.Tensor(1.1194297, shape=(), dtype=float32)\n",
      "tf.Tensor(1.1103768, shape=(), dtype=float32)\n",
      "tf.Tensor(1.1014035, shape=(), dtype=float32)\n",
      "tf.Tensor(1.0925096, shape=(), dtype=float32)\n",
      "tf.Tensor(1.0836947, shape=(), dtype=float32)\n",
      "tf.Tensor(1.0749587, shape=(), dtype=float32)\n",
      "tf.Tensor(1.0663007, shape=(), dtype=float32)\n",
      "tf.Tensor(1.0577208, shape=(), dtype=float32)\n",
      "tf.Tensor(1.0492184, shape=(), dtype=float32)\n",
      "tf.Tensor(1.0407933, shape=(), dtype=float32)\n",
      "tf.Tensor(1.0324446, shape=(), dtype=float32)\n",
      "tf.Tensor(1.0241723, shape=(), dtype=float32)\n",
      "tf.Tensor(1.0159761, shape=(), dtype=float32)\n",
      "tf.Tensor(1.0078553, shape=(), dtype=float32)\n",
      "tf.Tensor(0.99980944, shape=(), dtype=float32)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(0.99183846, shape=(), dtype=float32)\n",
      "tf.Tensor(0.9839415, shape=(), dtype=float32)\n",
      "tf.Tensor(0.9761182, shape=(), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "for dist_inputs in dist_dataset:\n",
    "  print(distributed_train_step(dist_inputs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jK8eQXF_q1Zs"
   },
   "source": [
    "在上面的示例中，我们通过迭代 `dist_dataset` 为训练提供输入。我们还提供 `tf.distribute.Strategy.make_experimental_numpy_dataset` 以支持 Numpy 输入。您可以在调用 `tf.distribute.Strategy.experimental_distribute_dataset` 之前使用此 API 来创建数据集。\n",
    "\n",
    "迭代数据的另一种方式是显式地使用迭代器。当您希望运行给定数量的步骤而非迭代整个数据集时，可能会用到此方式。现在可以将上面的迭代修改为：先创建迭代器，然后在迭代器上显式地调用 `next` 以获得输入数据。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-14T22:37:59.372915Z",
     "iopub.status.busy": "2022-12-14T22:37:59.372297Z",
     "iopub.status.idle": "2022-12-14T22:37:59.653476Z",
     "shell.execute_reply": "2022-12-14T22:37:59.652732Z"
    },
    "id": "e5BEvR0-LJAc"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(0.9683682, shape=(), dtype=float32)\n",
      "tf.Tensor(0.9606909, shape=(), dtype=float32)\n",
      "tf.Tensor(0.9530859, shape=(), dtype=float32)\n",
      "tf.Tensor(0.9455528, shape=(), dtype=float32)\n",
      "tf.Tensor(0.9380911, shape=(), dtype=float32)\n",
      "tf.Tensor(0.93070024, shape=(), dtype=float32)\n",
      "tf.Tensor(0.92337984, shape=(), dtype=float32)\n",
      "tf.Tensor(0.91612923, shape=(), dtype=float32)\n",
      "tf.Tensor(0.90894806, shape=(), dtype=float32)\n",
      "tf.Tensor(0.90183586, shape=(), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "iterator = iter(dist_dataset)\n",
    "for _ in range(10):\n",
    "  print(distributed_train_step(next(iterator)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vDJO8mnypqBA"
   },
   "source": [
    "这涵盖了使用 `tf.distribute.Strategy` API 分布自定义训练循环的最简单情况。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BZjNwCt1qBdw"
   },
   "source": [
    "### 目前支持的策略\n",
    "\n",
    "训练 API | `MirroredStrategy` | `TPUStrategy` | `MultiWorkerMirroredStrategy` | `ParameterServerStrategy` | `CentralStorageStrategy`\n",
    ":-- | :-- | :-- | :-- | :-- | :--\n",
    "自定义训练循环 | 支持 | 支持 | 实验性支持 | 实验性支持 | 实验性支持\n",
    "\n",
    "### 示例和教程\n",
    "\n",
    "以下是一些利用自定义训练循环使用分布策略的示例：\n",
    "\n",
    "1. [教程](../tutorials/distribute/custom_training.ipynb)：使用自定义训练循环和 `MirroredStrategy` 进行训练。\n",
    "2. [教程](../tutorials/distribute/multi_worker_with_ctl.ipynb)：使用自定义训练循环和 `MultiWorkerMirroredStrategy` 进行训练。\n",
    "3. [指南](tpu.ipynb)：包含使用 `TPUStrategy` 的自定义训练循环的示例。\n",
    "4. [教程](../tutorials/distribute/parameter_server_training.ipynb)：使用自定义训练循环和 `ParameterServerStrategy` 进行参数服务器训练。\n",
    "5. 包含使用各种策略实现的最先进模型集合的 TensorFlow Model Garden [仓库](https://github.com/tensorflow/models/tree/master/official)。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Xk0JdsTHyUnE"
   },
   "source": [
    "## 其他主题\n",
    "\n",
    "本部分涵盖与多个用例相关的一些主题。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cP6BUIBtudRk"
   },
   "source": [
    "<a name=\"TF_CONFIG\"></a>\n",
    "\n",
    "### 设置 TF_CONFIG 环境变量\n",
    "\n",
    "对于多工作进程训练，如前所述，您需要为集群中运行的每个二进制文件设置 `'TF_CONFIG'` 环境变量。`'TF_CONFIG'` 环境变量是一个 JSON 字符串，它指定了哪些任务构成集群、任务的地址以及每个任务在集群中的角色。[`tensorflow/ecosystem`](https://github.com/tensorflow/ecosystem) 仓库提供了一个 Kubernetes 模板，该模板会为您的训练任务设置 `'TF_CONFIG'`。\n",
    "\n",
    "`'TF_CONFIG'` 有两个组件：集群和任务。\n",
    "\n",
    "- 集群提供有关训练集群的信息，后者是一个由不同类型的作业（例如工作进程）组成的字典。在多工作进程训练中，除了常规工作进程执行的作业之外，通常会有一个工作进程承担更多职责，例如保存检查点和为 TensorBoard 编写摘要文件。这种工作进程被称为“首席”工作进程，习惯上将索引为 `0` 的工作进程指定为首席工作进程（实际上这就是 `tf.distribute.Strategy` 的实现方式）。\n",
    "- 另一方面，任务提供有关当前任务的信息。第一个组件集群对所有工作进程都相同，第二个组件任务在每个工作进程上都不同，并指定该工作进程的类型和索引。\n",
    "\n",
    "`'TF_CONFIG'` 的示例如下：\n",
    "\n",
    "```python\n",
    "os.environ[\"TF_CONFIG\"] = json.dumps({\n",
    "    \"cluster\": {\n",
    "        \"worker\": [\"host1:port\", \"host2:port\", \"host3:port\"],\n",
    "        \"ps\": [\"host4:port\", \"host5:port\"]\n",
    "    },\n",
    "   \"task\": {\"type\": \"worker\", \"index\": 1}\n",
    "})\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fezd3aF8wj9r"
   },
   "source": [
    "此 `'TF_CONFIG'` 指定 `\"cluster\"` 中有三个工作进程和两个 `\"ps\"` 任务以及它们的主机和端口。`\"task\"` 部分指定当前任务在 `\"cluster\"` 中的角色：工作进程 `1`（第二个工作进程）。集群中的有效角色是 `\"chief\"`、`\"worker\"`、`\"ps\"` 和 `\"evaluator\"`。除了使用 `tf.distribute.experimental.ParameterServerStrategy` 时，不应当存在 `\"ps\"` 作业。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GXIbqSW-sFVg"
   },
   "source": [
    "## 后续计划\n",
    "\n",
    "我们正在积极开发 `tf.distribute.Strategy`。欢迎试用，并通过 [GitHub 议题](https://github.com/tensorflow/tensorflow/issues/new)提供反馈。"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "Tce3stUlHN0L"
   ],
   "name": "distributed_training.ipynb",
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
